# Kestrel-3
#
# PLA-type Instruction Decoder
#
# This file is a case-study on how to write a PLA-style instruction
# decoder using nmigen.  Such decoders are found at the heart of most
# 70s and 80s-era CPUs, such as the 6502, 65816, Z80A, and TMS9900.
#
# The KCP53000 processor was written in hand-crafted Verilog,
# augmented with code generated from a tool which I dubbed State
# Machine Generator, or SMG.  SMG was written in Shen Lisp, and
# emitted equivalent Verilog.  The input was designed to resemble (to
# a reasonable extent) a state table, such as what you'd find in a TTL
# chip data-sheet.
#
# Writing a similar tool for nmigen seems to be superfluous, as nmigen
# has better semantics than raw Verilog.  This allows us to directly
# encode the PLA logic using more conventional constructs.
#
# The nmigen approach is still somewhat less convenient than SMG,
# because it is more verbose.  However, it has the distinct advantage
# that we don't need to repeat conditions for each minterm.
#
# The "instruction decoder" in this example is deliberately kept
# simple (only three instructions) but at the same time somewhat
# realistic for an 8-bit processor design (the SRC instruction has a
# mode which requires memory access, and so, uses more than the
# average number of execution cycles).
#
# This example DOES NOT show how to implement overlapping states.  In
# fact, they're expressly forbidden (see the formal verification rules
# in the PLAFormal class).  If you need overlapping states, then you
# can implement them in different "classes" of states.  For example,
# the execution engine might have states named X0, X1, X2, and so on.
# An instruction fetch unit might have states named I0, I1, I2, etc.
# States in the I-class and X-class can run concurrently; however, no
# two X-states and no two I-states can be active at the same time.
#
# Regarding states, each state has two signals in the module's
# interface: a Q-signal and a D-signal.  The D-signals are outputs
# from this module, to be inputs to D flip flops.  The Q output
# of each flip flop is an input to the decoder module.  For instance,
# if there are three X states, then there will be three QX inputs
# and three DX outputs.
#
#     +------------------+
#     |   Decoder Logic  |
#     |                  |
#     |  dxN         qxN |
#     +------------------+
#        |            ^
#        |  +------+  |
#        +--| D  Q |--+
#           |      |
#           +------+
#              DFF
#
# To invoke the tests in this file, simply run:
#
#     python -m unittest test-pla
#
# Assuming you have nmigen and yosys (and their respective
# dependencies) installed, after a few seconds, the tests should
# return successfully.


from nmigen.test.tools import FHDLTestCase
from nmigen.back.pysim import Simulator, Delay
from nmigen import (
    Signal, Module, Elaboratable, Const,
)
from nmigen.hdl.ast import Assert, Assume


def create_interface(self):
    """
    Attach the port signals shared by the decoder and its test driver.

    Every signal is stored as an attribute of ``self`` under its own
    name.  ``opcode`` is 8 bits wide; all other signals are single-bit.
    """

    # The instruction byte is the only multi-bit port.
    self.opcode = Signal(8)

    # Single-bit ports, listed in creation order.  Each one is named
    # explicitly after the attribute it is stored under.
    single_bit_ports = (
        # Present-state (Q) signals.
        "qx0", "qx1", "qx2",

        # Next-state (D) signals.
        "dx0", "dx1", "dx2",

        # Decoded sequencer signals.  Note that bus_ready is read by
        # the decoder rather than driven by it.
        "rs_to_regsel",
        "rd_to_regsel",
        "reg_write",
        "rs_to_t",
        "reg_to_t",
        "t_write",
        "reg_to_addr",
        "bus_read",
        "bus_ready",
        "bus_to_t",

        # Formal verification-only; provides glass-box visibility.
        "fv_mov_insn",
        "fv_alu_insn",
        "fv_src_insn",
    )
    for port in single_bit_ports:
        setattr(self, port, Signal(name=port))


class PLA(Elaboratable):
    """PLA-style instruction decoder and sequencer (the unit under test).

    All logic is combinational.  The two most significant opcode bits
    select the instruction class (0: MOV, 1: ALU, 2: SRC; 3 matches no
    term), and for SRC the three least significant bits select the
    operand mode.  Together with the present-state inputs ``qx0``,
    ``qx1`` and ``qx2`` these decide which control outputs are asserted
    and which next-state output (``dx0``..``dx2``) is raised.  Every
    output is 0 unless a matching term drives it to 1.
    """

    def __init__(self):
        super().__init__()
        create_interface(self)

    def elaborate(self, platform):
        """Build and return the Module holding the decoder logic.

        ``platform`` is not used.
        """
        m = Module()
        comb = m.d.comb

        is_mov = Signal()
        is_alu = Signal()
        is_src = Signal()

        # Instruction class decode, mirrored onto the fv_* outputs so
        # that the formal test bench can observe it.
        comb += [
            is_mov.eq(self.opcode[6:8] == Const(0, 2)),
            is_alu.eq(self.opcode[6:8] == Const(1, 2)),
            is_src.eq(self.opcode[6:8] == Const(2, 2)),

            self.fv_mov_insn.eq(is_mov),
            self.fv_alu_insn.eq(is_alu),
            self.fv_src_insn.eq(is_src),
        ]

        # Defaults: everything is negated unless a term below asserts it.
        comb += [
            self.rs_to_regsel.eq(0),
            self.rd_to_regsel.eq(0),
            self.reg_write.eq(0),
            self.rs_to_t.eq(0),
            self.reg_to_t.eq(0),
            self.t_write.eq(0),
            self.reg_to_addr.eq(0),
            self.bus_read.eq(0),
            self.bus_to_t.eq(0),
            self.dx0.eq(0),
            self.dx1.eq(0),
            self.dx2.eq(0),
        ]

        # MOV: X0 -> X1 -> X0.
        with m.If(is_mov):
            with m.If(self.qx0):
                comb += [
                    self.rs_to_regsel.eq(1),
                    self.dx1.eq(1),
                ]
            with m.If(self.qx1):
                comb += [
                    self.rd_to_regsel.eq(1),
                    self.reg_write.eq(1),
                    self.dx0.eq(1),
                ]

        # ALU: X0 -> X1 -> X0.
        with m.If(is_alu):
            with m.If(self.qx0):
                comb += [
                    self.rd_to_regsel.eq(1),
                    self.dx1.eq(1),
                ]
            with m.If(self.qx1):
                comb += [
                    self.rd_to_regsel.eq(1),
                    self.reg_write.eq(1),
                    self.dx0.eq(1),
                ]

        with m.If(is_src):
            # SRC mode 0: completes in X0 and stays there.
            with m.If(self.opcode[0:3] == Const(0, 3)):
                with m.If(self.qx0):
                    comb += [
                        self.rs_to_t.eq(1),
                        self.t_write.eq(1),
                        self.dx0.eq(1),
                    ]

            # SRC mode 1: X0 -> X1 -> X0.
            with m.If(self.opcode[0:3] == Const(1, 3)):
                with m.If(self.qx0):
                    comb += [
                        self.rs_to_regsel.eq(1),
                        self.dx1.eq(1),
                    ]
                with m.If(self.qx1):
                    comb += [
                        self.reg_to_t.eq(1),
                        self.t_write.eq(1),
                        self.dx0.eq(1),
                    ]

            # SRC mode 2: X0 -> X1 (repeated until bus_ready) -> X2 -> X0.
            with m.If(self.opcode[0:3] == Const(2, 3)):
                with m.If(self.qx0):
                    comb += [
                        self.rs_to_regsel.eq(1),
                        self.dx1.eq(1),
                    ]
                with m.If(self.qx1):
                    comb += [
                        self.reg_to_addr.eq(1),
                        self.bus_read.eq(1),
                    ]
                    # The wait condition must test bus_ready (the
                    # input), not bus_read: bus_read is driven to 1
                    # just above, so a term guarded by ~bus_read could
                    # never fire and X1 would have no successor while
                    # the bus is busy.
                    with m.If(self.bus_ready):
                        comb += self.dx2.eq(1)
                    with m.Else():
                        comb += self.dx1.eq(1)
                with m.If(self.qx2):
                    comb += [
                        self.bus_to_t.eq(1),
                        self.t_write.eq(1),
                        self.dx0.eq(1),
                    ]

        return m


class PLAFormal(Elaboratable):
    """A test to see how I can implement PLA-style logic
    without having to write a code generator.

    With this simple example, we decode instructions from a
    hypothetical CISCy processor.  The instructions are:

    00sssddd - MOV Rs, Rd
    01oooddd - (ADD, SUB, CMP, AND, OR, XOR, LDR, STR) Rd
    10sssmmm - source operand load into temporary register.
               0: 3-bit "quick" (unsigned) immediate
               1: register direct
               2: register indirect
               3-7: unused
    11------ - Unused.

    This class is the formal test bench: it instantiates a PLA as a
    submodule, wires this module's interface signals to it, and states
    the assumptions and assertions that the decoder must satisfy.
    """

    def __init__(self):
        super().__init__()
        create_interface(self)

    def elaborate(self, platform):
        """Build and return the Module holding the DUT and its properties.

        ``platform`` is not used.
        """
        m = Module()
        comb = m.d.comb

        dut = PLA()
        m.submodules.dut = dut

        # Drive the DUT's inputs from this module's signals, and copy
        # the DUT's outputs back out so the properties below can be
        # written against ``self``.
        comb += [
            dut.opcode.eq(self.opcode),
            dut.qx0.eq(self.qx0),
            dut.qx1.eq(self.qx1),
            dut.qx2.eq(self.qx2),

            self.rs_to_regsel.eq(dut.rs_to_regsel),
            self.rd_to_regsel.eq(dut.rd_to_regsel),
            self.reg_write.eq(dut.reg_write),
            self.rs_to_t.eq(dut.rs_to_t),
            self.reg_to_t.eq(dut.reg_to_t),
            self.t_write.eq(dut.t_write),

            self.reg_to_addr.eq(dut.reg_to_addr),
            self.bus_read.eq(dut.bus_read),
            dut.bus_ready.eq(self.bus_ready),
            self.bus_to_t.eq(dut.bus_to_t),

            self.dx0.eq(dut.dx0),
            self.dx1.eq(dut.dx1),
            self.dx2.eq(dut.dx2),

            self.fv_mov_insn.eq(dut.fv_mov_insn),
            self.fv_alu_insn.eq(dut.fv_alu_insn),
            self.fv_src_insn.eq(dut.fv_src_insn),
        ]

        # The instruction class flags must follow the top two opcode
        # bits, with at most one flag set at a time.
        mov = self.fv_mov_insn
        alu = self.fv_alu_insn
        src = self.fv_src_insn

        with m.If(self.opcode[6:8] == Const(0, 2)):
            comb += Assert(mov & ~alu & ~src)
        with m.If(self.opcode[6:8] == Const(1, 2)):
            comb += Assert(~mov & alu & ~src)
        with m.If(self.opcode[6:8] == Const(2, 2)):
            comb += Assert(~mov & ~alu & src)
        with m.If(self.opcode[6:8] == Const(3, 2)):
            comb += Assert(~mov & ~alu & ~src)

        # States in the same category must be non-overlapping: no two
        # of the X-state inputs may be set together.
        comb += [
            Assume(~(self.qx0 & self.qx1)),
            Assume(~(self.qx0 & self.qx2)),
            Assume(~(self.qx1 & self.qx2)),
        ]

        # Instruction decoder tests.
        with m.If(self.fv_mov_insn):
            with m.If(self.qx0):
                comb += [
                    Assert(self.rs_to_regsel),
                    Assert(~self.reg_write),
                    Assert(self.dx1),
                ]
            with m.If(self.qx1):
                comb += [
                    Assert(self.rd_to_regsel),
                    Assert(self.reg_write),
                    Assert(self.dx0),
                ]

        with m.If(self.fv_alu_insn):
            with m.If(self.qx0):
                comb += [
                    Assert(self.rd_to_regsel),
                    Assert(~self.reg_write),
                    Assert(self.dx1),
                ]
            with m.If(self.qx1):
                comb += [
                    Assert(self.rd_to_regsel),
                    Assert(self.reg_write),
                    Assert(self.dx0),
                ]

        with m.If(self.fv_src_insn):
            with m.If(self.opcode[0:3] == Const(0, 3)):
                with m.If(self.qx0):
                    comb += [
                        Assert(self.t_write),
                        Assert(self.rs_to_t),
                        Assert(self.dx0),
                    ]
            with m.If(self.opcode[0:3] == Const(1, 3)):
                with m.If(self.qx0):
                    comb += [
                        Assert(self.rs_to_regsel),
                        Assert(self.dx1),
                    ]
                with m.If(self.qx1):
                    comb += [
                        Assert(self.reg_to_t),
                        Assert(self.t_write),
                        Assert(self.dx0),
                    ]
            with m.If(self.opcode[0:3] == Const(2, 3)):
                with m.If(self.qx0):
                    comb += [
                        Assert(self.rs_to_regsel),
                        Assert(self.dx1),
                    ]
                with m.If(self.qx1):
                    comb += [
                        Assert(self.reg_to_addr),
                        Assert(self.bus_read),
                    ]
                    # The next state depends on the bus_ready input.
                    # Guarding on ~bus_read here would be vacuous,
                    # because bus_read is asserted throughout X1.
                    with m.If(self.bus_ready):
                        comb += Assert(self.dx2)
                    with m.Else():
                        comb += Assert(self.dx1)
                with m.If(self.qx2):
                    comb += [
                        Assert(self.bus_to_t),
                        Assert(self.t_write),
                        Assert(self.dx0),
                    ]

        return m


class PLATestCase(FHDLTestCase):
    """unittest entry point that runs the formal check of the decoder."""

    def test_pla_formally(self):
        # Hand the PLAFormal test bench to the formal flow in bounded
        # model checking mode.
        spec = PLAFormal()
        self.assertFormal(spec, mode="bmc", depth=100)